import numpy as np
import pandas as pd
import os
from collections import Counter

# Console display settings for the debug prints in this script.
# Each option is set exactly once; these are the values that were actually in
# effect after the earlier, mutually overriding assignments.
pd.set_option('display.width', None)  # let pandas auto-detect the terminal width
pd.set_option('display.max_columns', None)  # show all columns
pd.set_option('display.max_rows', 300)  # show at most 300 rows

# To restore the default setting:
# pd.reset_option("display.max_rows")


def get_illegal_ids_by_inter_num(inter_feat,field,min_num):
    """Collect the values of ``inter_feat[field]`` that occur fewer than ``min_num`` times.

    Args:
        inter_feat: Table indexed by column name whose ``[field]`` entry exposes
            a ``.values`` sequence (e.g. a pandas DataFrame).
        field: Name of the column whose values are counted.
        min_num: Minimum number of occurrences a value needs to be kept out of
            the result.

    Returns:
        set: The distinct values whose occurrence count is below ``min_num``.
        The size of the set and the field name are also printed.
    """
    occurrences = Counter(inter_feat[field].values)

    rare_ids = set()
    for value, count in occurrences.items():
        if count < min_num:
            rare_ids.add(value)

    print('[{}] illegal_ids_by_inter_num, field=[{}]'.format(len(rare_ids), field))
    return rare_ids

def data_generate(path):
    """Build the semi-synthetic ``ITE_kuaishou`` table from the raw interaction CSV.

    Steps, in order:
      1. Read ``ITE_kuaishou.csv`` from ``path`` and drop the ``author_id`` column.
      2. Forward-fill missing values.
      3. Randomly remove half of the rows with ``is_living == 0``
         (uses the global NumPy random state, so results vary between runs
         unless the caller seeds ``np.random``).
      4. Divide the three play-time columns by 1000.
      5. Derive the ``mu1`` / ``mu0`` arrays from the un-normalised columns.
      6. Integer-encode the categorical columns and min-max scale every column
         except the user id, item id and ``play_time_ms``.
      7. Rename columns to ``user_id`` / ``item_id`` / ``treatment`` / ``yf`` /
         ``x*`` / ``v*`` and insert ``mu1`` and ``mu0`` at positions 2 and 3.

    Progress information and the resulting frame are printed to stdout.

    Args:
        path: Directory that contains the raw ``ITE_kuaishou.csv``.

    Returns:
        None.
    """
    data_name = 'ITE_kuaishou'
    # Target location is <prefix up to 'UITE'>/dataset/ITE_kuaishou.
    # NOTE(review): if 'UITE' is not part of this file's path, str.find returns
    # -1 and root_path silently becomes the first three characters of cur_path.
    cur_path = os.path.abspath(os.path.dirname(__file__))
    root_path = cur_path[:cur_path.find('UITE') + len('UITE')]
    data_path = os.path.join(root_path, 'dataset', data_name)

    open_file_path = os.path.join(path, 'ITE_kuaishou.csv')
    save_file_path = os.path.join(data_path, 'ITE_kuaishou.csv')

    uid_field = "a.user_id"
    iid_field = "photo_id"
    # Columns that are replaced by integer codes before scaling.
    label_columns = ["a.user_age_segment", "city_name",
                     "hetu_level_two_tag", "a.user_id", 'photo_id']

    inter_feat = pd.read_csv(open_file_path).drop(labels='author_id', axis=1)
    print('Interaction num: {}'.format(len(inter_feat)))
    print('Treatment ratio: {}%'.format(_treatment_percent(inter_feat)))

    # DataFrame.ffill() replaces the deprecated fillna(method="ffill").
    inter_feat = inter_feat.ffill()

    # Disabled: filtering of users/items with fewer than 5 interactions.
    # ban_users = get_illegal_ids_by_inter_num(inter_feat,"user_id",5)
    # ban_items = get_illegal_ids_by_inter_num(inter_feat,"photo_id",5)
    # dropped_inter = pd.Series(False, index=inter_feat.index)
    # if len(ban_users) != 0:
    #     dropped_inter |= inter_feat[uid_field].isin(ban_users)
    # if len(ban_items) !=0:
    #     dropped_inter |= inter_feat[iid_field].isin(ban_items)
    # inter_feat.drop(inter_feat.index[dropped_inter],inplace=True)
    # inter_feat.reset_index(inplace=True,drop=True)
    #
    # print('[{}-{}={}] dropped interactions'.format(len(dropped_inter),sum(dropped_inter),len(inter_feat)))

    # Randomly discard half of the is_living == 0 rows (sampling without
    # replacement), which raises the share of is_living == 1 rows.
    control_index = inter_feat.index[inter_feat['is_living'] == 0]
    dropped_inter = np.random.choice(control_index,
                                     size=int(len(control_index) * 0.5),
                                     replace=False)
    inter_feat.drop(dropped_inter, inplace=True)
    inter_feat.reset_index(inplace=True, drop=True)
    # The message says "Dropped", but the number printed is the count of rows
    # that remain after the drop.
    print('Dropped interaction num: {}'.format(len(inter_feat)))
    print('Dropped treatment ratio: {}%'.format(_treatment_percent(inter_feat)))

    # Scale the play-time columns by 1/1000 (column names keep their "_ms" suffix).
    for column in ("play_time_ms", "user_avg_play_time_ms", "photo_avg_play_time_ms"):
        inter_feat[column] = inter_feat[column] / 1000

    # Must run before factorisation / scaling: it reads the raw feature values.
    mu1, mu0 = _simulate_mu(inter_feat)

    for label in label_columns:
        inter_feat[label] = pd.factorize(inter_feat[label])[0]

    inter_feat['hetu_level_one_tag'] = inter_feat['hetu_level_one_tag'].astype(int)

    _min_max_scale_columns(inter_feat, (uid_field, iid_field, 'play_time_ms'))

    rename_columns = {
        uid_field: "user_id",
        "is_living": "treatment",
        "play_time_ms": "yf",
        "photo_id": "item_id",
        "user_vv_count": "x0",
        "a.user_gender": "x1",
        "a.user_age_segment": "x2",
        "a.user_level": "x3",
        "city_name": "x4",
        "user_play_cnt_thirtyd": "x5",
        "user_avg_play_time_ms": "x6",
        "photo_exp_show": "v0",
        "photo_exp_click": "v1",
        "hetu_level_one_tag": "v2",
        "hetu_level_two_tag": "v3",
        "duration_ms": "v4",
        "photo_avg_play_time_ms": "v5",
    }
    inter_feat.rename(columns=rename_columns, inplace=True)

    inter_feat.insert(loc=2, column='mu1', value=mu1)
    inter_feat.insert(loc=3, column='mu0', value=mu0)
    user_ids = np.unique(inter_feat['user_id'].values)
    item_ids = np.unique(inter_feat['item_id'].values)
    print(user_ids)
    print(item_ids)
    print(len(user_ids), len(item_ids))
    # NOTE(review): writing the CSV is currently disabled, yet the message below
    # still reports a saved path. Re-enable the next line to actually save.
    # inter_feat.to_csv(save_file_path,index=False)
    print('[KuaiShou] saved path: {}\nDone!'.format(save_file_path))
    print(inter_feat)
    print(inter_feat.columns)


def _treatment_percent(inter_feat):
    """Return ``100 * sum(is_living) / row count``, rounded to the nearest integer."""
    return round(100 * inter_feat['is_living'].sum() / len(inter_feat))


def _simulate_mu(inter_feat):
    """Return the ``(mu1, mu0)`` arrays derived from the current feature values.

    ``Y_0`` and ``Y_1`` are each the mean of a photo term and a user term built
    from the average play times, weighted by ``1 - p`` and ``p`` respectively,
    where ``p`` is ``photo_exp_click / photo_exp_show`` for the photo term and
    ``user_play_cnt_thirtyd`` divided by its column maximum for the user term.
    The simulated value ``ycf`` is ``Y_0`` on rows with ``is_living == 1`` and
    ``Y_1`` elsewhere; ``mu1`` / ``mu0`` combine it with the observed
    ``play_time_ms``.
    """
    # NOTE(review): rows with photo_exp_show == 0 yield inf/NaN here - confirm
    # the raw data never contains such rows.
    photo_click_probability = (inter_feat['photo_exp_click'] / inter_feat['photo_exp_show']).values
    user_play_probability = (inter_feat['user_play_cnt_thirtyd']
                             / inter_feat['user_play_cnt_thirtyd'].max()).values

    photo_avg = inter_feat['photo_avg_play_time_ms'].values
    user_avg = inter_feat['user_avg_play_time_ms'].values

    Y_0 = (user_avg * (1 - user_play_probability) + photo_avg * (1 - photo_click_probability)) / 2.
    Y_1 = (photo_avg * photo_click_probability + user_avg * user_play_probability) / 2.

    ycf = np.where(inter_feat['is_living'] == 1, Y_0, Y_1)
    yf = inter_feat["play_time_ms"].values

    # Rows with is_living == 0 take the simulated value for mu1 and the
    # observed value for mu0; rows with is_living == 1 do the opposite.
    mu1 = np.where(inter_feat['is_living'] == 0, ycf, yf)
    mu0 = np.where(inter_feat['is_living'] == 1, ycf, yf)
    return mu1, mu0


def _min_max_scale_columns(inter_feat, skip_columns):
    """Min-max scale, in place, every column of ``inter_feat`` not in ``skip_columns``.

    A constant column has a zero range; it is mapped to 0.0 instead of the NaN
    that a 0/0 division would produce. Prints max, min and dtype per column.
    """
    for column in inter_feat.columns:
        if column in skip_columns:
            continue
        low = inter_feat[column].min()
        span = inter_feat[column].max() - low
        shifted = inter_feat[column] - low
        if span == 0:
            # Every value equals the minimum, so the shifted column is already 0.
            inter_feat[column] = shifted.astype(float)
        else:
            inter_feat[column] = shifted / span
        print(column, inter_feat[column].max(), inter_feat[column].min(), inter_feat[column].dtype)


if __name__ == "__main__":
    import sys

    # Directory holding the raw ITE_kuaishou.csv. Defaults to the original
    # hard-coded local dataset folder; pass another directory as the first
    # command-line argument to override it.
    default_path = "/Users/wangzhenlei/Desktop/KuaiShou/datasets"
    path = sys.argv[1] if len(sys.argv) > 1 else default_path
    data_generate(path)
